import torch
import os
import shutil
from safetensors.torch import load_file, save_file
from safetensors import safe_open
import argparse
def convert_dataformat(bf16_file, fp16_file):
    """Rewrite a safetensors file with its floating-point tensors cast to float16.

    Floating-point tensors (e.g. bf16 or fp32 weights) are converted with
    ``tensor.to(torch.float16)``. Non-floating tensors (integer/bool buffers)
    are kept with their original dtype, because casting them to float16 would
    corrupt their values or lose range. The output file is written with
    ``{"format": "pt"}`` metadata.

    Args:
        bf16_file: Path to the source ``.safetensors`` file.
        fp16_file: Path to write the converted ``.safetensors`` file to.

    Raises:
        FileNotFoundError: If ``bf16_file`` does not exist.
    """
    if not os.path.exists(bf16_file):
        # Raise a real exception type; raising a bare string is itself a TypeError.
        raise FileNotFoundError(f"{bf16_file} do not exist!")

    fp16_state_dict = {}
    with safe_open(bf16_file, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            # Only cast floating-point tensors; integer/bool tensors must keep their dtype.
            if tensor.is_floating_point():
                tensor = tensor.to(torch.float16)
            fp16_state_dict[key] = tensor
    save_file(fp16_state_dict, fp16_file, metadata={"format": "pt"})
    print("convert successed:", fp16_file)

def convert_ckpt(src_ckpt_path, dst_ckpt_path):
    """Copy a checkpoint directory tree, converting ``.safetensors`` files to fp16.

    Every ``*.safetensors`` file under ``src_ckpt_path`` is rewritten with
    ``convert_dataformat``; every other file is copied byte-for-byte. The
    relative directory layout of the source is reproduced under
    ``dst_ckpt_path``, so same-named files in different subdirectories
    (e.g. several ``config.json``) do not overwrite each other.

    Two kinds of directory are skipped:
    - ``.git`` directories.
    - The destination directory itself, when it lies inside the source tree.

    Args:
        src_ckpt_path: Root directory of the source checkpoint.
        dst_ckpt_path: Root directory for the converted checkpoint. It is
            created if missing.
    """
    os.makedirs(dst_ckpt_path, exist_ok=True)
    dst_real = os.path.realpath(dst_ckpt_path)

    for root, dirs, files in os.walk(src_ckpt_path):
        # Prune in place so os.walk does not descend into these directories.
        # Skipping dst avoids re-copying freshly written output when dst is under src.
        dirs[:] = [
            d for d in dirs
            if d != ".git" and os.path.realpath(os.path.join(root, d)) != dst_real
        ]

        # Mirror the source subdirectory under the destination root.
        rel_dir = os.path.relpath(root, src_ckpt_path)
        out_dir = os.path.normpath(os.path.join(dst_ckpt_path, rel_dir))
        os.makedirs(out_dir, exist_ok=True)

        for item in files:
            src_file = os.path.join(root, item)
            dst_file = os.path.join(out_dir, item)
            if item.endswith(".safetensors"):
                convert_dataformat(src_file, dst_file)
            else:
                shutil.copyfile(src_file, dst_file)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_ckpt_path', default='./src')
    parser.add_argument('--dst_ckpt_path', default='./dst')
    args = parser.parse_args()
    convert_ckpt(src_ckpt_path=args.src_ckpt_path, dst_ckpt_path=args.dst_ckpt_path)